import hnswlib
#import numpy as np
import sys
import time
#from sklearn.neighbors import NearestNeighbors

from util.graphdb_base import GraphDBBase

class DopamineParticipantSimilarities(GraphDBBase):
    """Compute approximate nearest neighbours between participants and store them.

    Vectors are read from ``DopamineASD`` and ``DopamineDD`` nodes, indexed
    with hnswlib, and the resulting neighbours are written back as weighted
    relationships between ``Participant`` nodes.
    """

    # Cypher expression that yields the vector for each supported vector_type.
    _VECTOR_EXPRESSIONS = {
        "gene": "n.dopamineGeneDosageVector",
        "go": "n.dopamineGoVector",
        "genego": "(n.dopamineGeneDosageVector + n.dopamineGoVector)",
    }

    # The same projection is applied to both node labels and merged with UNION.
    _VECTOR_QUERY_TEMPLATE = """
        MATCH (n:DopamineASD)
        RETURN toInteger(n.participantId) as participantId, {vector} as vector
        UNION MATCH (n:DopamineDD)
        RETURN toInteger(n.participantId) as participantId, {vector} as vector
    """

    def __init__(self, argv):
        """Initialise the base class with this script's path and CLI arguments."""
        super().__init__(command=__file__, argv=argv)

    def compute_and_store_similarities(self, k, distance_function, relationship_name, vector_type, threshold):
        """Run the full pipeline: load vectors, find neighbours, store relationships.

        :param k: number of nearest neighbours requested per participant
        :param distance_function: hnswlib space name (``l2``, ``cosine`` or ``ip``)
        :param relationship_name: relationship type used for the stored edges
        :param vector_type: one of ``gene``, ``go`` or ``genego``
        :param threshold: neighbours whose similarity is not above this are dropped
        :raises ValueError: if ``vector_type`` is not supported
        """
        start = time.time()
        data, participant_ids = self.get_dopamine_vectors(vector_type)
        print("Time to get vectors:", time.time() - start)

        if not participant_ids:
            # Nothing to index; building an index over zero vectors would fail.
            print("No vectors found for vector type:", vector_type)
            return

        start = time.time()
        ann_ids, ann_similarities = self.compute_ann(data, participant_ids, k, distance_function)
        print("Time to compute nearest neighbors:", time.time() - start)

        start = time.time()
        self.store_ann(participant_ids, ann_ids, ann_similarities, relationship_name, threshold)
        print("Time to store nearest neighbors:", time.time() - start)
        print("done")

    def get_dopamine_vectors(self, vector_type):
        """Fetch participant ids and their vectors from the graph.

        :param vector_type: ``gene``, ``go`` or ``genego`` (gene and GO vectors
            concatenated by the query)
        :return: ``(data, participant_ids)`` as two parallel lists
        :raises ValueError: if ``vector_type`` is not one of the supported values
        """
        try:
            vector_expression = self._VECTOR_EXPRESSIONS[vector_type]
        except KeyError:
            raise ValueError(
                "Unsupported vector_type {!r}; expected one of {}".format(
                    vector_type, sorted(self._VECTOR_EXPRESSIONS)
                )
            ) from None
        list_of_transaction_query = self._VECTOR_QUERY_TEMPLATE.format(vector=vector_expression)

        data = []
        participant_ids = []
        with self._driver.session() as session:
            for count, record in enumerate(session.run(list_of_transaction_query), start=1):
                data.append(record["vector"])
                participant_ids.append(record["participantId"])
                if count % 10000 == 0:
                    print(count, "participants processed")
            print(len(participant_ids), "participants processed")
        return data, participant_ids

    def compute_ann(self, data, participant_ids, k, distance_function):
        """Build an HNSW index over ``data`` and query it with the same vectors.

        :param data: list of equal-length vectors
        :param participant_ids: integer labels, parallel to ``data``
        :param k: number of neighbours requested; clamped to the number of
            indexed elements because hnswlib cannot return more than that
        :param distance_function: hnswlib space name (``l2``, ``cosine`` or ``ip``)
        :return: ``(ids, distances)`` as returned by ``knn_query``, one row per
            input vector; two empty lists when ``data`` is empty
        """
        if len(data) == 0:
            return [], []

        dim = len(data[0])
        num_elements = len(participant_ids)
        k = min(k, num_elements)

        index = hnswlib.Index(space=distance_function, dim=dim)
        # The maximum number of elements must be known before insertion.
        index.init_index(max_elements=num_elements, ef_construction=800, M=100, random_seed=42)
        index.add_items(data, participant_ids)
        # ef controls recall and must not be lower than k.
        index.set_ef(max(800, k))
        ids, distances = index.knn_query(data, k=k)
        return ids, distances

    def store_ann(self, participant_ids, ann_ids, ann_distances, label, threshold):
        """Replace each participant's outgoing ``label`` relationships with its neighbours.

        For every participant the existing outgoing relationships of type
        ``label`` are deleted and new ones are merged, with ``weight`` set to
        ``1 - distance``. Neighbours whose weight is not above ``threshold``
        are skipped.

        :param participant_ids: ids parallel to the rows of ``ann_ids``
        :param ann_ids: per-participant neighbour labels
        :param ann_distances: per-participant neighbour distances
        :param label: relationship type; interpolated into the Cypher text
            because relationship types cannot be passed as query parameters,
            so it must come from trusted code
        :param threshold: minimum (exclusive) similarity for a stored edge
        """
        clean_query = """
            MATCH (participant:Participant)-[s:{}]->()
            WHERE participant.participantId = $participantId
            DELETE s
        """.format(label)

        # The <> (not equal) comparison avoids self-loops.
        query = """
            MATCH (participant:Participant)
            WHERE participant.participantId = $participantId
            UNWIND keys($knn) as otherParticipantId
            MATCH (other:Participant)
            WHERE other.participantId = otherParticipantId and other.participantId <> $participantId
            MERGE (participant)-[:{} {{weight: $knn[otherParticipantId]}}]->(other)
        """.format(label)

        with self._driver.session() as session:
            rows = zip(participant_ids, ann_ids, ann_distances)
            for count, (participant_id, neighbour_ids, neighbour_distances) in enumerate(rows, start=1):
                knn_map = {}
                for neighbour_id, distance in zip(neighbour_ids, neighbour_distances):
                    similarity = 1.0 - float(distance)  # similarity = 1 - distance
                    if similarity > threshold:
                        knn_map[str(neighbour_id)] = similarity

                # Delete and re-create in one transaction so a participant is
                # never left with only half of the update applied.
                tx = session.begin_transaction()
                tx.run(clean_query, {"participantId": str(participant_id)})
                tx.run(query, {"participantId": str(participant_id), "knn": knn_map})
                tx.commit()

                if count % 1000 == 0:
                    print(count, "participants processed")


if __name__ == '__main__':
    # Compute and store the three dopamine similarity relationships in turn.
    knn = DopamineParticipantSimilarities(sys.argv[1:])
    try:
        # "cosine" makes hnswlib return distances; store_ann turns them into
        # similarities with 1 - distance.
        distance_formula_value = "cosine"
        k = 200
        threshold = 1e-03
        runs = (
            ("DOPAMINE_GENETIC_SIMILARITY", "gene"),
            ("DOPAMINE_GO_SIMILARITY", "go"),
            ("DOPAMINE_GENETIC_GO_SIMILARITY", "genego"),
        )
        for relationship_name, vector_type in runs:
            knn.compute_and_store_similarities(k, distance_formula_value, relationship_name, vector_type, threshold)
    finally:
        # The original referenced knn.close without calling it, so the
        # connection was never closed.
        knn.close()